﻿#include "mathutils.h"
#include <iostream>
#include <fstream>
#include <sstream>
#include <algorithm>
#include <numeric>
#include <QTextStream>
#include <QFile>
#include <QDebug>
#include <QtMath>

// Default constructor: performs no work of its own.
MathUtils::MathUtils() = default;

/**
 * Evaluates linearInterpolation() at every x in xAxis.
 *
 * @param xAxis      Positions to evaluate, processed in order.
 * @param dataPoints Curve samples forwarded unchanged to linearInterpolation().
 * @return One interpolated value per entry of xAxis, in the same order.
 */
QList<double> MathUtils::LinearInterpolation(QList<double> xAxis, const QList<QPointF> &dataPoints)
{
    // Iterate through a const alias so the (shared) QList is never detached.
    const QList<double> &positions = xAxis;

    QList<double> interpolated;
    interpolated.reserve(positions.size());
    for (const double position : positions)
        interpolated.append(linearInterpolation(position, dataPoints));
    return interpolated;
}

double MathUtils::linearInterpolation(double x_interpolate, const QList<QPointF> &dataPoints)
{
    int n = dataPoints.size();
    if (n < 2)
        return 0.0;

    for (int i = 0; i < n - 1; ++i)
    {
        if (x_interpolate >= dataPoints[i].x() && x_interpolate <= dataPoints[i + 1].x())
        {
            double x1 = dataPoints[i].x();
            double x2 = dataPoints[i + 1].x();
            double y1 = dataPoints[i].y();
            double y2 = dataPoints[i + 1].y();

            return y1 + (x_interpolate - x1) * (y2 - y1) / (x2 - x1);
        }
    }

    return 0.0; // Default value if interpolation point is out of range
}

/**
 * Loads a comma-separated numeric matrix via LoadMat() and flattens it in
 * column-major order (all of column 0, then all of column 1, ...).
 *
 * @param filename Path passed to LoadMat().
 * @param row      Out: number of rows read (0 if nothing was loaded).
 * @param column   Out: number of values in the first row (0 if nothing was loaded).
 * @return Column-major values; empty when LoadMat() returned no rows.
 * @throws std::out_of_range if a row has fewer values than the first row.
 */
vector<double> MathUtils::LoadMat2(string filename, int &row,int &column)
{
    const vector<vector<double>> data = LoadMat(filename);
    vector<double> ret;

    // LoadMat() returns an empty matrix when the file cannot be opened or has
    // no data; report 0x0 instead of throwing from data.at(0).
    if (data.empty())
    {
        row = 0;
        column = 0;
        return ret;
    }

    const size_t rowCount = data.size();
    const size_t colCount = data.front().size();
    row = static_cast<int>(rowCount);
    column = static_cast<int>(colCount);

    ret.reserve(rowCount * colCount);
    for (size_t i = 0; i < colCount; i++)
    {
        for (size_t j = 0; j < rowCount; j++)
        {
            // at(): a row shorter than the first one throws instead of
            // reading past the end of the inner vector.
            ret.push_back(data[j].at(i));
        }
    }
    return ret;
}

/**
 * Reads a comma-separated numeric matrix, one row per line.
 *
 * A trailing '\r' (Windows line endings) is stripped, and blank or
 * whitespace-only lines are skipped.
 *
 * @param filename Path of the text file to read.
 * @return One vector of doubles per non-blank line; empty if the file cannot
 *         be opened (a message is printed to stdout in that case).
 * @throws std::invalid_argument / std::out_of_range from std::stod when a
 *         cell is not a parsable number.
 */
vector<vector<double> > MathUtils::LoadMat(string filename)
{
    std::vector<std::vector<double>> data;

    std::ifstream file(filename);
    if (!file.is_open()) {
        std::cout << "Failed to open file: " << filename << std::endl;
        return data;
    }

    std::string line;
    while (std::getline(file, line)) {
        // getline keeps the '\r' of a CRLF line ending; drop it.
        if (!line.empty() && line.back() == '\r')
            line.pop_back();

        // Blank lines carry no values. Previously they either produced an
        // empty row (breaking rectangular consumers such as LoadMat2) or,
        // for whitespace-only lines, made std::stod throw.
        if (line.find_first_not_of(" \t") == std::string::npos)
            continue;

        std::vector<double> row;
        std::stringstream ss(line);
        std::string cell;
        while (std::getline(ss, cell, ',')) {
            row.push_back(std::stod(cell));
        }

        data.push_back(row);
    }
    // std::ifstream closes the file in its destructor.
    return data;
}

/**
 * Maps each name in currentElement to its position in allElement.
 *
 * Uses the first occurrence when a name appears more than once in
 * allElement; names not present in allElement are silently skipped, so the
 * result can be shorter than currentElement.
 */
vector<int> MathUtils::FindIndex(vector<string> currentElement, vector<string> allElement)
{
    vector<int> positions;
    const auto first = allElement.cbegin();
    const auto last = allElement.cend();

    for (const auto &name : currentElement)
    {
        const auto hit = std::find(first, last, name);
        if (hit == last)
            continue;
        positions.push_back(static_cast<int>(std::distance(first, hit)));
    }
    return positions;
}

/**
 * Qt overload: maps each string in currentElement to its position in
 * allElement (first occurrence). Strings not present in allElement are
 * skipped, so the result can be shorter than currentElement.
 */
QList<int> MathUtils::FindIndex(QList<QString> currentElement, QList<QString> allElement)
{
    QList<int> positions;
    // Const alias: range-for must not detach the implicitly shared list.
    const QList<QString> &wanted = currentElement;

    for (const QString &name : wanted)
    {
        // indexOf() returns the first matching index, or -1 when absent.
        const int position = allElement.indexOf(name);
        if (position >= 0)
            positions.append(position);
    }
    return positions;
}

/**
 * Splits source on every occurrence of splitChar.
 *
 * Tokenisation follows std::getline: consecutive delimiters yield empty
 * tokens, a trailing delimiter does not add a final empty token, and an
 * empty source yields an empty vector.
 */
vector<string> MathUtils::Split(string source,char splitChar)
{
    vector<string> tokens;
    std::istringstream stream(source);
    for (std::string token; std::getline(stream, token, splitChar); )
        tokens.push_back(token);
    return tokens;
}

/**
 * Applies per-block linear corrections (k * x, plus b on intercept terms) to
 * a flattened coefficient vector.
 *
 * oldXs is treated as consecutive blocks of itemLength values; block m covers
 * indices [m * itemLength, (m + 1) * itemLength). For every index inside a
 * block listed in modifyIndex (the first matching entry j wins):
 *   - an "intercept" value is replaced by x * ks[j] + bs[j];
 *   - a zero value at the very last position is copied unchanged;
 *   - any other value is replaced by x * ks[j].
 * Indices outside every listed block are copied unchanged.
 *
 * Intercept heuristic (as implemented): a non-zero value that is either the
 * last element of oldXs or is followed by 0.0.
 *
 * @param oldXs       Flattened values to correct.
 * @param itemLength  Number of values per block.
 * @param ks          Scale factor per entry of modifyIndex.
 * @param bs          Offset per entry of modifyIndex (used on intercepts only).
 * @param modifyIndex Block numbers to correct.
 * @return Corrected vector, same length as oldXs.
 * @throws std::out_of_range if ks or bs has no entry for a matching position
 *         of modifyIndex (previously an unchecked out-of-bounds read).
 */
vector<double> MathUtils::ModifyXs(vector<double> oldXs, int itemLength, vector<double> ks, vector<double> bs, vector<int> modifyIndex)
{
    vector<double> ret;
    ret.reserve(oldXs.size());

    const int modelxsSize = static_cast<int>(oldXs.size());
    const int column = static_cast<int>(modifyIndex.size());
    qDebug()<<"modelxsSize:"<<modelxsSize<<"column"<<column;

    for (int i = 0; i < modelxsSize; i++)
    {
        const double x = oldXs.at(i);

        // Find the first correction entry whose block contains index i
        // (i.e. whether this value belongs to an element being corrected).
        int match = -1;
        for (int j = 0; j < column; j++)
        {
            if (i >= modifyIndex[j] * itemLength && i < (modifyIndex[j] + 1) * itemLength)
            {
                match = j;
                break;
            }
        }

        if (match < 0)
        {
            ret.push_back(x); // not a corrected element: keep the original value
            continue;
        }

        const bool isLast = (i == modelxsSize - 1);
        // The last position doubles as the boundary case of the intercept test.
        const bool isIntercept = (x != 0.0) && (isLast || oldXs.at(i + 1) == 0.0);

        // .at(): a ks/bs shorter than modifyIndex throws instead of reading
        // out of bounds. bs is only touched for intercepts, as before.
        if (isIntercept)
            ret.push_back(x * ks.at(match) + bs.at(match));
        else if (isLast)
            ret.push_back(x);
        else
            ret.push_back(x * ks.at(match));
    }
    return ret;
}

/**
 * Pearson correlation coefficient of x and y.
 *
 * Uses the two-pass (mean-centred) formulation. The one-pass form
 * n*Σxy - Σx*Σy cancels catastrophically when the data have a large offset,
 * and its variance terms can come out slightly negative through rounding,
 * making the square root NaN.
 *
 * @return r in [-1, 1]; 0.0 when sizes differ, input is empty, or either
 *         series has zero variance.
 */
double MathUtils::CalcCorrelation(vector<double> x, vector<double> y)
{
    const size_t n = x.size();
    if (n != y.size() || n == 0)
        return 0.0;

    double meanX = 0.0;
    double meanY = 0.0;
    for (size_t i = 0; i < n; ++i)
    {
        meanX += x[i];
        meanY += y[i];
    }
    meanX /= static_cast<double>(n);
    meanY /= static_cast<double>(n);

    double sumXY = 0.0;
    double sumXX = 0.0;
    double sumYY = 0.0;
    for (size_t i = 0; i < n; ++i)
    {
        const double dx = x[i] - meanX;
        const double dy = y[i] - meanY;
        sumXY += dx * dy;
        sumXX += dx * dx;
        sumYY += dy * dy;
    }

    // A constant series has no defined correlation.
    if (sumXX <= 0.0 || sumYY <= 0.0)
        return 0.0;

    // Take the roots separately to avoid under/overflow of sumXX * sumYY.
    const double r = sumXY / (qSqrt(sumXX) * qSqrt(sumYY));
    // Rounding can push |r| a hair above 1.
    return std::max(-1.0, std::min(1.0, r));
}

/**
 * Squared Pearson correlation between observed values x and predicted values y.
 *
 * Note: this is r^2 of the two series, not the coefficient of determination
 * 1 - SS_res / SS_tot.
 *
 * Uses the two-pass (mean-centred) formulation. The one-pass form loses
 * precision for data with a large offset and can yield NaN when rounding
 * makes a variance term negative.
 *
 * @return r^2 in [0, 1]; 0.0 when sizes differ, input is empty, or either
 *         series has zero variance.
 */
double MathUtils::CalcR2(vector<double> x, vector<double> y)
{
    const size_t n = x.size();
    if (n != y.size() || n == 0)
    {
        return 0.0;  // Mismatched or empty input.
    }

    const double meanObserved = std::accumulate(x.begin(), x.end(), 0.0) / static_cast<double>(n);
    const double meanPredicted = std::accumulate(y.begin(), y.end(), 0.0) / static_cast<double>(n);

    double sumProduct = 0.0;
    double sumSqObserved = 0.0;
    double sumSqPredicted = 0.0;
    for (size_t i = 0; i < n; ++i)
    {
        const double dObs = x[i] - meanObserved;
        const double dPred = y[i] - meanPredicted;
        sumProduct += dObs * dPred;
        sumSqObserved += dObs * dObs;
        sumSqPredicted += dPred * dPred;
    }

    if (sumSqObserved <= 0.0 || sumSqPredicted <= 0.0)
    {
        return 0.0;  // A constant series has no defined correlation.
    }

    // Separate roots avoid under/overflow; clamp absorbs rounding past |r| = 1.
    const double r = std::max(-1.0, std::min(1.0,
                         sumProduct / (std::sqrt(sumSqObserved) * std::sqrt(sumSqPredicted))));
    return r * r;
}

/**
 * Reads a text file and splits each line on ',' into whitespace-trimmed cells.
 *
 * This is a plain comma split: quoted fields containing commas are not
 * handled.
 *
 * @param filePath File to read.
 * @return One row per line; empty if the file cannot be opened (no error is
 *         reported).
 */
QVector<QVector<QString>> MathUtils::ReadCsvFile(const QString &filePath)
{
    QVector<QVector<QString>> table;

    QFile csv(filePath);
    if (!csv.open(QIODevice::ReadOnly | QIODevice::Text))
        return table; // Open failure: caller receives an empty table.

    QTextStream stream(&csv);
    while (!stream.atEnd())
    {
        const QStringList cells = stream.readLine().split(',');

        QVector<QString> cleaned;
        cleaned.reserve(cells.size());
        for (const QString &cell : cells)
            cleaned.append(cell.trimmed());

        table.append(cleaned);
    }

    // QFile closes the device when it goes out of scope.
    return table;
}

/**
 * Transposes a row-major table of strings.
 *
 * The column count is taken from the first row. Rows shorter than the first
 * row (e.g. a blank line read by ReadCsvFile, which yields a single empty
 * cell) leave their missing cells as empty strings instead of being read out
 * of bounds. Cells beyond the first row's width are dropped.
 *
 * @param matrix Row-major input; may be empty.
 * @return Table with numCols rows of numRows cells each.
 */
QVector<QVector<QString> > MathUtils::SwapRowColumn(const QVector<QVector<QString> > &matrix)
{
    const int numRows = matrix.size();
    const int numCols = (numRows > 0) ? static_cast<int>(matrix.first().size()) : 0;

    // Every output row starts as numRows empty strings.
    QVector<QVector<QString>> swappedMatrix(numCols, QVector<QString>(numRows));

    for (int j = 0; j < numRows; ++j)
    {
        const QVector<QString> &source = matrix.at(j);
        // Copy only the cells this row actually has.
        const int available = std::min(numCols, static_cast<int>(source.size()));
        for (int i = 0; i < available; ++i)
        {
            swappedMatrix[i][j] = source.at(i);
        }
    }

    return swappedMatrix;
}

/**
 * Parses each string with QString::toDouble().
 *
 * Strings that do not parse are logged through qDebug() and omitted, so the
 * result can be shorter than the input.
 */
QVector<double> MathUtils::ConvertToDoubleVector(const QVector<QString> &vec)
{
    QVector<double> numbers;
    numbers.reserve(vec.size());

    for (auto it = vec.cbegin(); it != vec.cend(); ++it)
    {
        bool parsed = false;
        const double number = it->toDouble(&parsed);
        if (!parsed)
        {
            qDebug() << "Invalid double value: " << *it;
            continue;
        }
        numbers.append(number);
    }
    return numbers;
}

/**
 * Arithmetic mean of values (summed left to right).
 *
 * @throws std::runtime_error if values is empty.
 */
double MathUtils::Average(vector<double> values)
{
    if (values.size() == 0)
        throw std::runtime_error("Vector is empty.");

    double total = 0.0;
    for (const double value : values)
        total += value;
    return total / values.size();
}
